{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前言\n",
    "\n",
    "原文翻译自：[Deep Learning with PyTorch: A 60 Minute Blitz](https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html)\n",
    "\n",
    "翻译：[林不清](https://www.zhihu.com/people/lu-guo-92-42-88)\n",
    "\n",
    "整理：机器学习初学者公众号 ![gongzhong](images/gongzhong.jpg)\n",
    "\n",
    "## 目录\n",
    "\n",
    "[60分钟入门PyTorch（一）——Tensors](https://zhuanlan.zhihu.com/p/347676809)\n",
    "\n",
    "[60分钟入门PyTorch（二）——Autograd自动求导](https://zhuanlan.zhihu.com/p/347672836)\n",
    "\n",
    "[60分钟入门Pytorch（三）——神经网络](https://zhuanlan.zhihu.com/p/347678492)\n",
    "\n",
    "[60分钟入门PyTorch（四）——训练一个分类器](https://zhuanlan.zhihu.com/p/347681137)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练一个分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "你已经学会如何去定义一个神经网络,计算损失值和更新网络的权重。\n",
    "\n",
    "你现在可能在思考：数据哪里来呢？\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 关于数据"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "通常，当你处理图像，文本，音频和视频数据时，你可以使用标准的Python包来加载数据到一个numpy数组中.然后把这个数组转换成`torch.*Tensor`。\n",
    "* 对于图像,有诸如Pillow,OpenCV包等非常实用\n",
    "* 对于音频,有诸如scipy和librosa包\n",
    "* 对于文本,可以用原始Python和Cython来加载,或者使用NLTK和SpaCy\n",
    "对于视觉,我们创建了一个`torchvision`包,包含常见数据集的数据加载,比如Imagenet,CIFAR10,MNIST等,和图像转换器,也就是`torchvision.datasets`和`torch.utils.data.DataLoader`。\n",
    "\n",
    "这提供了巨大的便利,也避免了代码的重复。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在这个教程中,我们使用CIFAR10数据集,它有如下10个类别:’airplane’,’automobile’,’bird’,’cat’,’deer’,’dog’,’frog’,’horse’,’ship’,’truck’。这个数据集中的图像大小为3\\*32\\*32,即,3通道,32\\*32像素。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "![cifar10](images/cifar10.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 训练一个图像分类器"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们将按照下列顺序进行:\n",
    "* 使用`torchvision`加载和归一化CIFAR10训练集和测试集.\n",
    "* 定义一个卷积神经网络\n",
    "* 定义损失函数\n",
    "* 在训练集上训练网络\n",
    "* 在测试集上测试网络\n",
    "\n",
    "### 1. 加载和归一化CIFAR10\n",
    "使用`torchvision`加载CIFAR10是非常容易的。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torchvision\n",
    "import torchvision.transforms as transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "torchvision的输出是[0,1]的PILImage图像,我们把它转换为归一化范围为[-1, 1]的张量。\n",
    "<div class=\"alert alert-info\"><h4>注意</h4><p>如果在Windows上运行时出现BrokenPipeError，尝试将torch.utils.data.DataLoader()的num_worker设置为0。</p></div>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
   ],
   "source": [
    "transform = transforms.Compose(\n",
    "    [transforms.ToTensor(),\n",
    "     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n",
    "\n",
    "trainset = torchvision.datasets.CIFAR10(root='./data', train=True,\n",
    "                                        download=True, transform=transform)\n",
    "trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,\n",
    "                                          shuffle=True, num_workers=2)\n",
    "\n",
    "testset = torchvision.datasets.CIFAR10(root='./data', train=False,\n",
    "                                       download=True, transform=transform)\n",
    "testloader = torch.utils.data.DataLoader(testset, batch_size=4,\n",
    "                                         shuffle=False, num_workers=2)\n",
    "\n",
    "classes = ('plane', 'car', 'bird', 'cat',\n",
    "           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')\n",
    "#这个过程有点慢，会下载大约340mb图片数据。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们展示一些有趣的训练图像。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import numpy as np\n",
    "\n",
    "# functions to show an image\n",
    "\n",
    "\n",
    "def imshow(img):\n",
    "    img = img / 2 + 0.5     # unnormalize\n",
    "    npimg = img.numpy()\n",
    "    plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
    "    plt.show()\n",
    "\n",
    "\n",
    "# get some random training images\n",
    "dataiter = iter(trainloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# show images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "# print labels\n",
    "print(' '.join('%5s' % classes[labels[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. 定义一个卷积神经网络\n",
    "从之前的神经网络一节复制神经网络代码,并修改为接受3通道图像取代之前的接受单通道图像。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "class Net(nn.Module):\n",
    "    def __init__(self):\n",
    "        super(Net, self).__init__()\n",
    "        self.conv1 = nn.Conv2d(3, 6, 5)\n",
    "        self.pool = nn.MaxPool2d(2, 2)\n",
    "        self.conv2 = nn.Conv2d(6, 16, 5)\n",
    "        self.fc1 = nn.Linear(16 * 5 * 5, 120)\n",
    "        self.fc2 = nn.Linear(120, 84)\n",
    "        self.fc3 = nn.Linear(84, 10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.pool(F.relu(self.conv1(x)))\n",
    "        x = self.pool(F.relu(self.conv2(x)))\n",
    "        x = x.view(-1, 16 * 5 * 5)\n",
    "        x = F.relu(self.fc1(x))\n",
    "        x = F.relu(self.fc2(x))\n",
    "        x = self.fc3(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "net = Net()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. 定义损失函数和优化器\n",
    "我们使用交叉熵作为损失函数,使用带动量的随机梯度下降。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.optim as optim\n",
    "\n",
    "criterion = nn.CrossEntropyLoss()\n",
    "optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. 训练网络\n",
    "这是开始有趣的时刻，我们只需在数据迭代器上循环,把数据输入给网络,并优化。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for epoch in range(2):  # loop over the dataset multiple times\n",
    "\n",
    "    running_loss = 0.0\n",
    "    for i, data in enumerate(trainloader, 0):\n",
    "        # get the inputs; data is a list of [inputs, labels]\n",
    "        inputs, labels = data\n",
    "\n",
    "        # zero the parameter gradients\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # forward + backward + optimize\n",
    "        outputs = net(inputs)\n",
    "        loss = criterion(outputs, labels)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print statistics\n",
    "        running_loss += loss.item()\n",
    "        if i % 2000 == 1999:    # print every 2000 mini-batches\n",
    "            print('[%d, %5d] loss: %.3f' %\n",
    "                  (epoch + 1, i + 1, running_loss / 2000))\n",
    "            running_loss = 0.0\n",
    "\n",
    "print('Finished Training')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "保存一下我们的训练模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "PATH = './cifar_net.pth'\n",
    "torch.save(net.state_dict(), PATH)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "点击[这里](https://pytorch.org/docs/stable/notes/serialization.html)查看关于保存模型的详细介绍"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. 在测试集上测试网络\n",
    "我们在整个训练集上训练了两次网络,但是我们还需要检查网络是否从数据集中学习到东西。\n",
    "\n",
    "我们通过预测神经网络输出的类别标签并根据实际情况进行检测，如果预测正确,我们把该样本添加到正确预测列表。\n",
    "\n",
    "第一步，显示测试集中的图片一遍熟悉图片内容。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "dataiter = iter(testloader)\n",
    "images, labels = dataiter.next()\n",
    "\n",
    "# print images\n",
    "imshow(torchvision.utils.make_grid(images))\n",
    "print('GroundTruth: ', ' '.join('%5s' % classes[labels[j]] for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来，让我们重新加载我们保存的模型(注意:保存和重新加载模型在这里不是必要的，我们只是为了说明如何这样做)："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = Net()\n",
    "net.load_state_dict(torch.load(PATH))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们来看看神经网络认为以上图片是什么?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "outputs = net(images)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "输出是10个标签的概率。一个类别的概率越大,神经网络越认为他是这个类别。所以让我们得到最高概率的标签。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, predicted = torch.max(outputs, 1)\n",
    "\n",
    "print('Predicted: ', ' '.join('%5s' % classes[predicted[j]]\n",
    "                              for j in range(4)))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "这结果看起来非常的好。\n",
    "\n",
    "接下来让我们看看网络在整个测试集上的结果如何。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "correct = 0\n",
    "total = 0\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs.data, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "print('Accuracy of the network on the 10000 test images: %d %%' % (\n",
    "    100 * correct / total))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "结果看起来好于偶然，偶然的正确率为10%,似乎网络学习到了一些东西。\n",
    "\n",
    "那在什么类上预测较好，什么类预测结果不好呢？\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_correct = list(0. for i in range(10))\n",
    "class_total = list(0. for i in range(10))\n",
    "with torch.no_grad():\n",
    "    for data in testloader:\n",
    "        images, labels = data\n",
    "        outputs = net(images)\n",
    "        _, predicted = torch.max(outputs, 1)\n",
    "        c = (predicted == labels).squeeze()\n",
    "        for i in range(4):\n",
    "            label = labels[i]\n",
    "            class_correct[label] += c[i].item()\n",
    "            class_total[label] += 1\n",
    "\n",
    "\n",
    "for i in range(10):\n",
    "    print('Accuracy of %5s : %2d %%' % (\n",
    "        classes[i], 100 * class_correct[i] / class_total[i]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来干什么?\n",
    "\n",
    "我们如何在GPU上运行神经网络呢?\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在GPU上训练\n",
    "你是如何把一个Tensor转换GPU上,你就如何把一个神经网络移动到GPU上训练。这个操作会递归遍历有所模块,并将其参数和缓冲区转换为CUDA张量。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# Assume that we are on a CUDA machine, then this should print a CUDA device:\n",
    "#假设我们有一台CUDA的机器，这个操作将显示CUDA设备。\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "接下来假设我们有一台CUDA的机器，然后这些方法将递归遍历所有模块并将其参数和缓冲区转换为CUDA张量："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "net.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "请记住，你也必须在每一步中把你的输入和目标值转换到GPU上:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs, labels = inputs.to(device), labels.to(device)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "为什么我们没注意到GPU的速度提升很多?那是因为网络非常的小。\n",
    "\n",
    "#### 实践:\n",
    "尝试增加你的网络的宽度(第一个`nn.Conv2d`的第2个参数, 第二个`nn.Conv2d`的第一个参数,他们需要是相同的数字),看看你得到了什么样的加速。\n",
    "\n",
    "#### 实现的目标:\n",
    "* 深入了解了PyTorch的张量库和神经网络\n",
    "* 训练了一个小网络来分类图片"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 在多GPU上训练\n",
    "如果你希望使用所有GPU来更大的加快速度,请查看选读:[数据并行](https://pytorch.org/tutorials/beginner/blitz/data_parallel_tutorial.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 接下来做什么?\n",
    "* 训练神经网络玩电子游戏\n",
    "* 在ImageNet上训练最好的ResNet\n",
    "* 使用对抗生成网络来训练一个人脸生成器\n",
    "* 使用LSTM网络训练一个字符级的语言模型\n",
    "* 更多示例\n",
    "* 更多教程\n",
    "* 在论坛上讨论PyTorch\n",
    "* 在Slack上与其他用户聊天"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
